import os
import json
import joblib
import openai
import numpy as np
import subprocess
import tiktoken
from tqdm import tqdm
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics.pairwise import cosine_similarity
from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk

from pipelines.prompta.rag.database import Database
from pipelines.prompta.rag.chunk import split_all, split_into_chunks_by_tokens
from prompta.utils.set_api import global_openai_client
from prompta.utils.config_helper import read_config, deep_merge_dicts


def start_elasticsearch(elasticsearch_path, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL):
    """Launch the Elasticsearch executable at *elasticsearch_path* as a child process.

    The process is only spawned here; this function does not wait for the
    server to accept connections.

    Args:
        elasticsearch_path: Relative or absolute path to the launcher executable.
        stdout, stderr: Destinations for the child's output streams, passed
            straight to ``subprocess.Popen``. They default to ``DEVNULL``.
            Use ``subprocess.PIPE`` only if the caller continuously drains the
            pipe. An unread pipe fills up and blocks the child once the OS
            buffer is full.

    Returns:
        The ``subprocess.Popen`` handle of the launched process.

    Raises:
        FileNotFoundError: if *elasticsearch_path* is not an existing file.
        Exception: whatever ``subprocess.Popen`` raises (logged, then re-raised).
    """
    # Resolve relative paths so the error message and Popen see the same file.
    elasticsearch_path = os.path.abspath(elasticsearch_path)

    # isfile (not exists): a directory would otherwise pass and fail later in Popen.
    if not os.path.isfile(elasticsearch_path):
        raise FileNotFoundError(f"File not found: {elasticsearch_path}")

    try:
        process = subprocess.Popen([elasticsearch_path], stdout=stdout, stderr=stderr)
    except Exception as e:
        # Logging boundary only; the original exception propagates unchanged.
        print(f"Failed to start Elasticsearch: {e}")
        raise
    print("Elasticsearch started successfully.")
    return process


# Packaged defaults shipped next to this module; both RAG system constructors
# deep-merge any caller-supplied config over these values.
_DEFAULT_CONFIG_PATH = os.path.join(os.path.dirname(__file__), 'rag_default_config.yaml')
default_config = read_config(_DEFAULT_CONFIG_PATH)


class RAGSystem:
    """Retrieval-augmented generation over text chunks kept in a ``Database``.

    Chunks are embedded with ``text-embedding-3-large`` through
    ``global_openai_client``. They are retrieved with one of two backends,
    selected by ``config.model.model_type``:

    * ``'knn'``: a scikit-learn ``NearestNeighbors`` index (cosine metric)
      fitted on all stored embeddings and persisted with joblib.
    * ``'elastic'``: an Elasticsearch ``match`` query over chunk text. Its hits
      are re-ranked by cosine similarity to the query embedding.

    ``generate_response`` appends the retrieved chunks to a copy of the last
    chat message and calls the ``gpt-4o`` chat-completions endpoint.
    """

    # Seconds to wait for a freshly launched Elasticsearch to answer ``ping()``
    # before indexing starts. Indexing proceeds after the timeout regardless,
    # so a server that never comes up still surfaces the client's own error.
    es_startup_timeout = 60
    # Minimum number of Elasticsearch hits fetched before embedding re-ranking.
    es_candidate_pool = 20

    def __init__(self, config=None):
        # Caller overrides are deep-merged over the packaged defaults; this
        # class only reads the ``rag`` section of the merged result.
        self.whole_config = deep_merge_dicts(default_config, config)
        self.config = self.whole_config.rag

        self.client = global_openai_client
        self.initialize()

    def initialize(self):
        """Open or build the chunk database, then prepare the retrieval backend.

        The stored database is reused when ``database.load_db`` is set and
        ``database.db_path`` exists. Otherwise every chunk returned by
        ``split_all(data.data_path)`` is embedded and inserted.

        For ``'knn'``, a saved model is loaded only when the database was
        reused as well. A model fitted on an earlier database would map
        neighbour indices onto the wrong rows of a freshly built one, so the
        index is refitted instead.

        For ``'elastic'``, an Elasticsearch process is launched. Once it
        answers (or the startup timeout expires), every stored chunk is
        indexed under its database id.
        """
        db_config = self.config.database
        # Evaluated before Database() is constructed, in case opening it creates db_path.
        reuse_db = bool(db_config.load_db and os.path.exists(db_config.db_path))
        self.db = Database(db_config)
        if not reuse_db:
            chunks = split_all(self.config.data.data_path)
            for chunk in tqdm(chunks, desc="Splitting and adding chunks"):
                self.add_chunk(chunk)

        model_config = self.config.model
        if model_config.model_type == 'knn':
            if reuse_db and model_config.load_model and os.path.exists(model_config.knn.model_path):
                self.load_embeddings()
                self.load_knn_model()
            else:
                self.rebuild()
        elif model_config.model_type == 'elastic':
            self.es_proc = start_elasticsearch(model_config.elastic.exec_path)
            self.es = Elasticsearch(model_config.elastic.host)
            self._wait_for_elasticsearch()
            for chunk_id, chunk, _embedding in tqdm(self.db.list_rows(), desc="Indexing chunks"):
                self.index_chunk_in_elasticsearch(chunk, chunk_id)

        self.data_consistency_flag = True

    def _wait_for_elasticsearch(self, timeout=None, interval=1.0):
        """Poll ``self.es.ping()`` until it is truthy or *timeout* seconds pass.

        Args:
            timeout: Seconds to wait; defaults to ``self.es_startup_timeout``.
            interval: Seconds to sleep between pings.

        Returns:
            True if the server answered within the timeout, otherwise False.
        """
        import time

        if timeout is None:
            timeout = self.es_startup_timeout
        deadline = time.monotonic() + timeout
        while not self.es.ping():
            if time.monotonic() >= deadline:
                return False
            time.sleep(interval)
        return True

    def load_embeddings(self):
        """Load all stored embeddings (JSON-decoded) and their row ids.

        ``self.embeddings[i]`` belongs to ``self.ids[i]``; the k-NN index relies
        on this ordering to map neighbour positions back to database rows.
        """
        # NOTE(review): reads 2-tuples from db.list_columns(), whereas
        # initialize() iterates 3-tuples from db.list_rows(); confirm the
        # Database API. Materialized because the rows are iterated twice.
        rows = list(self.db.list_columns())
        self.embeddings = [json.loads(embedding) for (_row_id, embedding) in rows]
        self.ids = [row_id for (row_id, _embedding) in rows]

    def build_knn_model(self):
        """Fit a cosine-metric ``NearestNeighbors`` index on ``self.embeddings`` and save it."""
        self.knn_model = NearestNeighbors(n_neighbors=5, metric='cosine')
        self.knn_model.fit(self.embeddings)
        self.save_knn_model()

    def save_knn_model(self):
        """Dump ``self.knn_model`` to ``model.knn.model_path`` with joblib, creating parent dirs."""
        model_path = self.config.model.knn.model_path
        save_dir = os.path.dirname(model_path)
        # A bare filename has no directory part, and os.makedirs('') would raise.
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        joblib.dump(self.knn_model, model_path)

    def load_knn_model(self):
        """Load the joblib-pickled k-NN model from ``model.knn.model_path``."""
        self.knn_model = joblib.load(self.config.model.knn.model_path)

    def rebuild(self):
        """Reload embeddings from the database and refit (and re-save) the k-NN index."""
        self.load_embeddings()
        self.build_knn_model()
        self.data_consistency_flag = True

    def generate_embedding(self, text):
        """Return the ``text-embedding-3-large`` embedding of *text*.

        If ``split_into_chunks_by_tokens(text, 4096)`` yields more than one
        piece (presumably pieces of at most 4096 tokens), each piece is embedded
        separately and the results are averaged by ``combine_embeddings``.
        """
        chunks = split_into_chunks_by_tokens(text, 4096)

        if len(chunks) > 1:
            embeddings = []
            for chunk in chunks:
                response = self.client.embeddings.create(
                    input=[chunk],
                    model="text-embedding-3-large"
                )
                embeddings.append(response.data[0].embedding)
            return self.combine_embeddings(embeddings)

        response = self.client.embeddings.create(
            input=[text],
            model="text-embedding-3-large"
        )
        return response.data[0].embedding

    def combine_embeddings(self, embeddings):
        """Average a list of equal-length embeddings element-wise; returns a plain list."""
        return np.mean(embeddings, axis=0).tolist()

    def add_chunk(self, chunk):
        """Embed *chunk*, store it, and mark the k-NN index as stale."""
        embedding = self.generate_embedding(chunk)
        self.db.add_chunk(chunk, embedding)
        self.data_consistency_flag = False

    def index_chunk_in_elasticsearch(self, chunk, chunk_id):
        """Index *chunk* in ``model.elastic.index`` with document id *chunk_id*.

        Using the database id as the Elasticsearch document id means that
        re-indexing on every start overwrites documents instead of duplicating them.
        """
        doc = {
            "text": chunk,
            "id": chunk_id,
        }
        self.es.index(index=self.config.model.elastic.index, id=str(chunk_id), body=doc)

    def retrieve_chunks_from_elasticsearch(self, query, k=20):
        """Full-text ``match`` search; returns up to *k* ``(text, chunk_id)`` pairs in ES score order."""
        body = {
            "query": {
                "match": {
                    "text": query
                }
            },
            "size": k
        }
        res = self.es.search(index=self.config.model.elastic.index, body=body)
        hits = res['hits']['hits']
        return [(hit['_source']['text'], hit['_source']['id']) for hit in hits]

    @staticmethod
    def _as_vector(embedding):
        """Decode *embedding* from JSON if it is stored as text; otherwise return it unchanged."""
        if isinstance(embedding, (str, bytes)):
            return json.loads(embedding)
        return embedding

    def retrieve_chunks(self, query, k=5):
        """Return up to *k* chunk texts most relevant to *query*, best first.

        Raises:
            ValueError: if ``model.model_type`` is neither ``'knn'`` nor ``'elastic'``.
        """
        model_type = self.config.model.model_type
        if model_type == 'knn' and not self.data_consistency_flag:
            self.rebuild()

        if model_type == 'elastic':
            # Fetch at least k candidates so re-ranking can actually return k results.
            candidates = self.retrieve_chunks_from_elasticsearch(query, k=max(self.es_candidate_pool, k))
            if not candidates:
                return []
            texts = [text for text, _chunk_id in candidates]
            query_embedding = np.array(self.generate_embedding(query)).reshape(1, -1)
            candidate_embeddings = np.array(
                [self._as_vector(self.db.get_embedding(chunk_id)) for _text, chunk_id in candidates]
            )
            similarities = cosine_similarity(query_embedding, candidate_embeddings).flatten()
            ranked_indices = np.argsort(similarities)[::-1]
            return [texts[idx] for idx in ranked_indices[:k]]

        if model_type == 'knn':
            # kneighbors raises if asked for more neighbours than fitted samples.
            n_neighbors = min(k, len(self.ids))
            if n_neighbors <= 0:
                return []
            query_embedding = np.array(self.generate_embedding(query)).reshape(1, -1)
            _distances, indices = self.knn_model.kneighbors(query_embedding, n_neighbors=n_neighbors)
            return [self.db.get_chunk(self.ids[idx]) for idx in indices[0]]

        raise ValueError(f"Unsupported model_type: {model_type!r}")

    def generate_response(self, queries, k=5, response_format=None, max_tokens=4096, n=1, seed=0):
        """Answer the chat *queries*, with retrieved context appended to the last message.

        *queries* is left untouched. The augmented last message is built on a
        copy, so repeated calls with the same list do not accumulate context.
        Only the first completion's text is returned, even if ``n > 1``.
        """
        relevant_chunks = self.retrieve_chunks(queries[-1]['content'], k)
        context = "Here are some relevant text chunks that might be useful to answer the question:\n```\n{}```\n".format('\n'.join(relevant_chunks))
        last_message = dict(queries[-1])
        last_message['content'] = '\n'.join((last_message['content'], context))
        messages = list(queries[:-1]) + [last_message]
        response = self.client.chat.completions.create(
            model="gpt-4o",
            response_format=response_format,  # e.g. {"type": "json_object"}
            messages=messages,
            max_tokens=max_tokens,  # maximum tokens in the response
            n=n,  # number of completions to generate
            stop=None,
            temperature=None,
            seed=seed
        )
        return response.choices[0].message.content


class SymbolRAGSystem:
    """Retrieve candidate symbols for a query and let the chat model choose among them.

    Each symbol comes from a description file in ``data.data_path`` whose name
    ends with ``data.data_extension``. The file name up to its first ``'.'`` is
    the symbol and the file contents are embedded with ``text-embedding-3-large``.
    Retrieval uses either a cosine k-NN index (``model_type == 'knn'``) or an
    Elasticsearch ``match`` query over the symbol names. The ES hits are
    re-ranked by embedding cosine similarity (``model_type == 'elastic'``).
    """

    # Seconds to wait for a freshly launched Elasticsearch to answer ``ping()``
    # before indexing starts. Indexing proceeds after the timeout regardless,
    # so a server that never comes up still surfaces the client's own error.
    es_startup_timeout = 60
    # Minimum number of Elasticsearch hits fetched before embedding re-ranking.
    es_candidate_pool = 20

    def __init__(self, config=None):
        # NOTE(review): unlike RAGSystem (which uses ``.rag``), the merged config
        # is read at top level (``database``, ``data``, ``model``); confirm the
        # YAML provides these keys.
        self.config = deep_merge_dicts(default_config, config)

        self.client = global_openai_client
        self.initialize()

    def initialize(self):
        """Open or build the symbol database, then prepare the retrieval backend.

        The stored database is reused when ``database.load_db`` is set and
        ``database.db_path`` exists. Otherwise every matching description file
        is read and added via ``add_symbol``. For ``'knn'``, a saved model is
        loaded only if the database was reused too; a model fitted on an
        earlier database would point at the wrong rows of a rebuilt one.
        """
        db_config = self.config.database
        # Evaluated before Database() is constructed, in case opening it creates db_path.
        reuse_db = bool(db_config.load_db and os.path.exists(db_config.db_path))
        self.db = Database(db_config)
        if not reuse_db:
            data_config = self.config.data
            # Sorted so insertion order does not depend on directory-listing order.
            desc_files = sorted(os.listdir(data_config.data_path))
            for desc_file in tqdm(desc_files, desc="Loading and adding symbols"):
                if not desc_file.endswith(data_config.data_extension):
                    continue
                desc_path = os.path.join(data_config.data_path, desc_file)
                symbol = desc_file.split('.')[0]
                with open(desc_path, 'r', encoding='utf-8') as desc_fh:
                    symbol_desc = desc_fh.read()
                self.add_symbol(symbol, symbol_desc)

        model_config = self.config.model
        if model_config.model_type == 'knn':
            if reuse_db and model_config.load_model and os.path.exists(model_config.knn.model_path):
                self.load_embeddings()
                self.load_knn_model()
            else:
                self.rebuild()
        elif model_config.model_type == 'elastic':
            self.es_proc = start_elasticsearch(model_config.elastic.exec_path)
            self.es = Elasticsearch(model_config.elastic.host)
            self._wait_for_elasticsearch()
            for row_id, symbol, _embedding in tqdm(self.db.list_rows(), desc="Indexing chunks"):
                self.index_chunk_in_elasticsearch(symbol, row_id)

        self.data_consistency_flag = True

    def _wait_for_elasticsearch(self, timeout=None, interval=1.0):
        """Poll ``self.es.ping()`` until it is truthy or *timeout* seconds pass.

        Args:
            timeout: Seconds to wait; defaults to ``self.es_startup_timeout``.
            interval: Seconds to sleep between pings.

        Returns:
            True if the server answered within the timeout, otherwise False.
        """
        import time

        if timeout is None:
            timeout = self.es_startup_timeout
        deadline = time.monotonic() + timeout
        while not self.es.ping():
            if time.monotonic() >= deadline:
                return False
            time.sleep(interval)
        return True

    def load_embeddings(self):
        """Load all stored embeddings (JSON-decoded) and their row ids, in matching order."""
        # Materialized because the rows are iterated twice.
        rows = list(self.db.list_rows())
        self.embeddings = [json.loads(embedding) for (_row_id, _symbol, embedding) in rows]
        self.ids = [row_id for (row_id, _symbol, _embedding) in rows]

    def build_knn_model(self):
        """Fit a cosine-metric ``NearestNeighbors`` index on ``self.embeddings`` and save it."""
        self.knn_model = NearestNeighbors(n_neighbors=5, metric='cosine')
        self.knn_model.fit(self.embeddings)
        self.save_knn_model()

    def save_knn_model(self):
        """Dump ``self.knn_model`` to ``model.knn.model_path`` with joblib, creating parent dirs."""
        model_path = self.config.model.knn.model_path
        save_dir = os.path.dirname(model_path)
        # A bare filename has no directory part, and os.makedirs('') would raise.
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        joblib.dump(self.knn_model, model_path)

    def load_knn_model(self):
        """Load the joblib-pickled k-NN model from ``model.knn.model_path``."""
        self.knn_model = joblib.load(self.config.model.knn.model_path)

    def rebuild(self):
        """Reload embeddings from the database and refit (and re-save) the k-NN index."""
        self.load_embeddings()
        self.build_knn_model()
        self.data_consistency_flag = True

    def generate_embedding(self, text):
        """Return the ``text-embedding-3-large`` embedding of *text*.

        If ``split_into_chunks_by_tokens(text, 4096)`` yields more than one
        piece (presumably pieces of at most 4096 tokens), each piece is embedded
        separately and the results are averaged by ``combine_embeddings``.
        """
        chunks = split_into_chunks_by_tokens(text, 4096)

        if len(chunks) > 1:
            embeddings = []
            for chunk in chunks:
                response = self.client.embeddings.create(
                    input=[chunk],
                    model="text-embedding-3-large"
                )
                embeddings.append(response.data[0].embedding)
            return self.combine_embeddings(embeddings)

        response = self.client.embeddings.create(
            input=[text],
            model="text-embedding-3-large"
        )
        return response.data[0].embedding

    def combine_embeddings(self, embeddings):
        """Average a list of equal-length embeddings element-wise; returns a plain list."""
        return np.mean(embeddings, axis=0).tolist()

    def add_symbol(self, symbol, symbol_desc):
        """Embed *symbol_desc*, store it JSON-encoded under *symbol*, and mark the index stale."""
        embedding = json.dumps(self.generate_embedding(symbol_desc))
        self.db.add_row(symbol, embedding)
        self.data_consistency_flag = False

    def index_chunk_in_elasticsearch(self, chunk, chunk_id):
        """Index *chunk* in ``model.elastic.index`` with document id *chunk_id*.

        Using the database id as the Elasticsearch document id means that
        re-indexing on every start overwrites documents instead of duplicating them.
        """
        doc = {
            "text": chunk,
            "id": chunk_id,
        }
        self.es.index(index=self.config.model.elastic.index, id=str(chunk_id), body=doc)

    def retrieve_chunks_from_elasticsearch(self, query, k=20):
        """Full-text ``match`` search; returns up to *k* ``(text, id)`` pairs in ES score order."""
        body = {
            "query": {
                "match": {
                    "text": query
                }
            },
            "size": k
        }
        res = self.es.search(index=self.config.model.elastic.index, body=body)
        hits = res['hits']['hits']
        return [(hit['_source']['text'], hit['_source']['id']) for hit in hits]

    @staticmethod
    def _as_vector(embedding):
        """Decode *embedding* from JSON if it is stored as text; otherwise return it unchanged."""
        if isinstance(embedding, (str, bytes)):
            return json.loads(embedding)
        return embedding

    def retrieve_symbols(self, query, k=20):
        """Return up to *k* symbols most relevant to *query*, best first.

        Raises:
            ValueError: if ``model.model_type`` is neither ``'knn'`` nor ``'elastic'``.
        """
        model_type = self.config.model.model_type
        if model_type == 'knn' and not self.data_consistency_flag:
            self.rebuild()

        if model_type == 'elastic':
            # Fetch at least k candidates so re-ranking can actually return k results.
            candidates = self.retrieve_chunks_from_elasticsearch(query, k=max(self.es_candidate_pool, k))
            if not candidates:
                return []
            texts = [text for text, _row_id in candidates]
            query_embedding = np.array(self.generate_embedding(query)).reshape(1, -1)
            # add_symbol stores JSON strings, so decode before building the matrix.
            candidate_embeddings = np.array(
                [self._as_vector(self.db.get_embedding(row_id)) for _text, row_id in candidates]
            )
            similarities = cosine_similarity(query_embedding, candidate_embeddings).flatten()
            ranked_indices = np.argsort(similarities)[::-1]
            return [texts[idx] for idx in ranked_indices[:k]]

        if model_type == 'knn':
            # kneighbors raises if asked for more neighbours than fitted samples.
            n_neighbors = min(k, len(self.ids))
            if n_neighbors <= 0:
                return []
            query_embedding = np.array(self.generate_embedding(query)).reshape(1, -1)
            _distances, indices = self.knn_model.kneighbors(query_embedding, n_neighbors=n_neighbors)
            return [self.db.get_object_by_id('symbol', self.ids[idx]) for idx in indices[0]]

        raise ValueError(f"Unsupported model_type: {model_type!r}")

    def generate_response(self, queries, k=20, response_format=None, max_tokens=128, n=1, seed=0):
        """Ask the chat model about *queries*, appending the candidate symbols to the last message.

        *queries* is left untouched. The augmented last message is built on a
        copy, so repeated calls with the same list do not accumulate context.
        Only the first completion's text is returned, even if ``n > 1``.
        """
        relevant_symbols = self.retrieve_symbols(queries[-1]['content'], k)
        context = "\n```\n[{}]```\n".format(', '.join(relevant_symbols))
        last_message = dict(queries[-1])
        last_message['content'] = '\n'.join((last_message['content'], context))
        messages = list(queries[:-1]) + [last_message]
        response = self.client.chat.completions.create(
            model="gpt-4o",
            response_format=response_format,  # e.g. {"type": "json_object"}
            messages=messages,
            max_tokens=max_tokens,  # maximum tokens in the response
            n=n,  # number of completions to generate
            stop=None,
            temperature=None,
            seed=seed
        )
        return response.choices[0].message.content


if __name__ == '__main__':
    # Manual smoke test: retrieve candidate symbols for one sub-goal and ask the
    # chat model (via global_openai_client) to pick the most appropriate one.
    #
    # Example for the chunk-based RAGSystem; it additionally needs
    # ``from pipelines.prompta.agent.prompts import *`` for the prompt constants:
    # system = RAGSystem()
    # print(system.retrieve_chunks("""How to create a stone pickaxe?""", 5))
    # origin_goal = "create a stone pickaxe"
    # print(system.generate_response([
    #         {"role": "system", "content": MAIN_GOAL_SYS_PROMPT},
    #         {"role": "user", "content": MAIN_GOAL_USER_PROMPT.format(origin_goal)}
    #     ]))
    # main_goal = {
    #     "CraftStonePickaxe": "Craft a stone pickaxe by collecting three blocks of cobblestone or its variants (blackstone or cobbled deepslate) with a wooden pickaxe and combining them with two sticks on a crafting table."
    # }
    # print(system.generate_response([
    #     {"role": "system", "content": SUBGOAL_SYS_PROMPT},
    #     {"role": "user", "content": SUBGOAL_USER_PROMPT.format("CraftStonePickaxe", ['craftItem', "exploreUntil", "givePlacedItemBack", "mineBlock", "placeItem"])}
    # ]))

    # Resolve the config next to this module instead of a machine-specific absolute path.
    config_path = os.path.join(os.path.dirname(__file__), 'rag_default_config.yaml')
    config = read_config(config_path)

    system = SymbolRAGSystem(config)

    queries = [
        {"role": "user", "content": "For this sub-goal: \"Mine[Log]: Mine a wood log from a nearby tree in the jungle biome.\", what is the most appropriate object?\nYou are currently located at position (x: 4.50, y: 90.00, z: 25.50) in a jungle biome. It is facing yaw: -3.14 and pitch: -1.57. You have health: 20, food: 20, and saturation: 5. The current time of day is day. Your velocity is (x: 0.00, y: -0.08, z: 0.00). Nearby entities include: a parrot at 19.77 blocks away, a chicken at 23.00 blocks away. You are surrounded by blocks such as stone, dirt, grass_block, coal_ore. Since the last observation, you have lost 1 of dirt."},
    ]
    print(system.generate_response(queries, 20))
